from typing import Optional, Dict
import io
from functools import partial
from PIL import Image as PILImage
import matplotlib
import matplotlib.ticker as tkr
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import wandb
import warnings

import torch


def force_aspect(ax, aspect=1):
    """Rescale ``ax`` so its first image is drawn with the given aspect.

    The data-space width/height ratio of the first image's extent is divided
    by ``aspect`` and passed to ``ax.set_aspect``.

    Args:
        ax: Matplotlib axes holding at least one image (e.g. from ``matshow``).
        aspect: target display ratio; ``1`` gives a square-looking image.
    """
    left, right, bottom, top = ax.get_images()[0].get_extent()
    data_ratio = abs((right - left) / (top - bottom))
    ax.set_aspect(data_ratio / aspect)


def plt_to_wandb_image(fig):
    """Rasterise a Matplotlib figure into a ``wandb.Image`` and close it.

    The figure is saved (tight bounding box, Matplotlib's default savefig
    format) into an in-memory buffer, reopened with PIL and wrapped for
    wandb. The figure is closed so repeated plotting does not accumulate
    open figures.

    Args:
        fig: Matplotlib figure to render.

    Returns:
        ``wandb.Image`` built from the rendered bitmap.
    """
    raster_buffer = io.BytesIO()
    fig.savefig(raster_buffer, bbox_inches="tight")
    plt.close(fig)
    # Rewind so PIL reads from the start of what savefig wrote.
    raster_buffer.seek(0)
    return wandb.Image(PILImage.open(raster_buffer))


def distribution_5D(x, **kwargs):
    """Plot standard-deviation maps of ``x[0]`` for every pair of its 5 axes.

    For each axis pair ``(i, j)`` with ``i < j``, ``x[0]`` is reduced with
    ``std`` over the other three axes and drawn with ``matshow`` in cell
    ``(i, j)`` of a 5x5 grid; unused cells are removed. Zero-valued entries
    are set to NaN so the colormap draws them in black.

    Args:
        x: array or tensor whose ``x[0]`` has five axes; tensors are moved to
            CPU and converted to NumPy first.
        **kwargs: ignored; accepted for a uniform plotting-function signature.

    Returns:
        ``wandb.Image`` of the figure.
    """
    del kwargs  # unused, kept for signature compatibility
    axis_names = [r"v_{par}", r"v_{\mu}", r"s", r"k_x", r"k_y"]

    if isinstance(x, torch.Tensor):
        x = x.cpu().detach().numpy()

    pairs = torch.combinations(torch.arange(5), 2).tolist()
    pair_set = {tuple(p) for p in pairs}

    fig, ax = plt.subplots(5, 5, figsize=(20, 20))
    for row in range(5):
        for col in range(5):
            if (row, col) not in pair_set:
                ax[row, col].remove()

    cmap = matplotlib.colormaps["coolwarm"]
    cmap.set_bad("k")

    last_labelled_row = -1
    for row, col in pairs:
        reduce_axes = tuple([o for o in range(5) if o not in (row, col)])
        spread = x[0].std(reduce_axes)
        spread[spread == 0] = np.nan
        panel = ax[row, col]
        panel.matshow(spread, cmap=cmap)

        # Only the first panel drawn in each row gets axis labels.
        if row > last_labelled_row:
            panel.set_ylabel(rf"${axis_names[row]}$", fontsize=20)
            panel.set_xlabel(rf"${axis_names[col]}$", fontsize=20)
            last_labelled_row = row

        force_aspect(panel)

    return plt_to_wandb_image(fig)


def plot4x4_sided(
    x1: torch.Tensor,
    x2: Optional[torch.Tensor] = None,
    title: str = "",
    plot_type: str = "mean_comparison",
):
    """Draw side-by-side 2D projections of a 5D field for every axis pair.

    For each axis pair ``(i, j)`` with ``i < j``, the three remaining axes of
    ``x1[0]`` (and ``x2[0]``) are reduced. The two resulting 2D maps are drawn
    next to each other inside grid cell ``(i, j)`` of a 5x5 layout.

    Args:
        x1: first field (labelled "PRED" in comparison mode). ``x1[0]`` must
            have five axes.
        x2: second field (labelled "GT"). Required unless
            ``plot_type == "mean_and_std"``, in which case it is ignored.
        title: figure super-title.
        plot_type: either ``"mean_and_std"`` or a string containing
            ``"comparison"`` and at least one of ``"mean"``, ``"std"``,
            ``"slice"``.

            * ``"mean_and_std"`` plots the mean and std of ``x1``, each on
              its own colour range.
            * A comparison type plots ``x1`` vs ``x2`` on a shared colour
              range with a colour bar.

            When several reducers are named, the last one checked wins
            (order: mean, std, slice).

    Returns:
        ``wandb.Image`` of the rendered figure.

    Raises:
        AssertionError: ``x2`` is None and ``plot_type`` is not
            ``"mean_and_std"``.
        ValueError: ``plot_type`` names no supported reduction.
    """
    if x2 is None:
        assert plot_type == "mean_and_std"

    # Validate before building the figure. An unsupported plot_type would
    # otherwise leave x1_plot/x2_plot unbound inside the loop below
    # (UnboundLocalError) after the figure has already been created.
    is_comparison = "comparison" in plot_type and any(
        reducer in plot_type for reducer in ("mean", "std", "slice")
    )
    if plot_type != "mean_and_std" and not is_comparison:
        raise ValueError(
            f"Unsupported plot_type {plot_type!r}: expected 'mean_and_std' or a "
            "'comparison' type containing 'mean', 'std' or 'slice'."
        )

    labels = [r"v_{par}", r"v_{\mu}", r"s", r"k_x", r"k_y"]
    comb = torch.combinations(torch.arange(5), 2).tolist()

    fig, ax = plt.subplots(5, 5, figsize=(30, 14))
    # Column 0 and row 4 never host an (i < j) pair, so remove them. The other
    # cells are kept but stripped of frame and ticks. For pairs, they carry
    # the labels and colour bar while the plots go on axes added below.
    for i in range(5):
        for j in range(5):
            if j == 0:
                ax[i, j].remove()
                continue
            if i == 4:
                ax[i, j].remove()
                continue
            ax_ij = ax[i, j]
            ax_ij.set_frame_on(False)
            ax_ij.tick_params(labelleft=False, labelbottom=False)
            ax_ij.set_xticks([])
            ax_ij.set_yticks([])

    fig.suptitle(title)
    c_map = matplotlib.colormaps["RdBu"]
    c_map.set_bad("k")

    for i, j in comb:
        other = tuple([o for o in range(5) if o != i and o != j])

        if "comparison" in plot_type:
            label1, label2 = r"PRED", r"GT"
            if "mean" in plot_type:
                x1_plot = x1[0].mean(other)
                x2_plot = x2[0].mean(other)
            if "std" in plot_type:
                x1_plot = x1[0].std(other)
                x2_plot = x2[0].std(other)
            if "slice" in plot_type:
                # Index 0 along each of the three remaining axes. as_tensor +
                # detach().cpu() also accept NumPy input, CUDA tensors and
                # grad-tracking tensors, without torch.tensor's copy warning.
                x1_plot = (
                    torch.as_tensor(x1[0])
                    .permute(i, j, *other)
                    .detach()
                    .cpu()
                    .numpy()[:, :, 0, 0, 0]
                )
                x2_plot = (
                    torch.as_tensor(x2[0])
                    .permute(i, j, *other)
                    .detach()
                    .cpu()
                    .numpy()[:, :, 0, 0, 0]
                )
        if plot_type == "mean_and_std":
            x1_plot = x1[0].mean(other)
            x2_plot = x1[0].std(other)
            label1, label2 = r"mean", r"std"

        ax_ij = ax[i, j]
        pos = ax_ij.get_position()

        # create two new axes within the same space as the original subplot
        plot_width = 0.475 * pos.width
        left_margin = 0.0 * pos.width
        x_left_1 = pos.x0 + left_margin
        x_left_2 = x_left_1 + plot_width
        y = pos.y0
        h = pos.height
        ax1 = fig.add_axes([x_left_1, y, plot_width, h])
        ax2 = fig.add_axes([x_left_2, y, plot_width, h])

        # mean_and_std: independent colour ranges (different quantities).
        # Comparisons: one shared range so both panels are directly comparable.
        if plot_type == "mean_and_std":
            vmin1 = x1_plot.min()
            vmin2 = x2_plot.min()
            vmax1 = x1_plot.max()
            vmax2 = x2_plot.max()
        else:
            vmin1 = vmin2 = min(x1_plot.min(), x2_plot.min())
            vmax1 = vmax2 = max(x1_plot.max(), x2_plot.max())

        im1 = ax1.matshow(x1_plot, cmap=c_map, vmin=vmin1, vmax=vmax1)
        ax2.matshow(x2_plot, cmap=c_map, vmin=vmin2, vmax=vmax2)

        if plot_type != "mean_and_std":
            cbar = fig.colorbar(
                im1,
                ax=[ax_ij],
                format=tkr.FormatStrFormatter("%.2g"),
                pad=0,
                fraction=0.05,
            )
            cbar.set_ticks([vmin1, (vmin1 + vmax1) / 2, vmax1])
            cbar.ax.tick_params(labelsize=12)

        if i == 0:
            ax1.set_title(label1, fontsize=24)
            ax2.set_title(label2, fontsize=24)

        if j == 1 or (i == 1 and j == 2) or (i == 2 and j == 3) or (i == 3 and j == 4):
            ax_ij.set_ylabel(rf"${labels[i]}$", fontsize=14)

        if i == 3 or j == 1 or (i == 1 and j == 2) or (i == 2 and j == 3):
            ax_ij.set_xlabel(rf"${labels[j]}$", fontsize=14)

        ax1.set_xticks([])
        ax1.set_yticks([])
        ax2.set_xticks([])
        ax2.set_yticks([])
        ax1.tick_params(labelleft=False, labelbottom=False)
        ax2.tick_params(labelleft=False, labelbottom=False)
        force_aspect(ax1)
        force_aspect(ax2)

    return plt_to_wandb_image(fig)


def mse_time_histogram(losses):
    """Bar-plot the mean loss per time step with standard-deviation error bars.

    Args:
        losses: mapping from time step to a collection of loss values
            (presumably per-sample MSEs). Keys are plotted in sorted order;
            each bar is ``np.mean`` of the values, and its error bar is
            ``np.std``.

    Returns:
        ``wandb.Image`` created directly from the figure, which is then closed.
    """
    fig, ax = plt.subplots(1, 1, figsize=(10, 5))
    steps = sorted(losses.keys())

    means, stds = [], []
    for step in steps:
        values = losses[step]
        means.append(np.mean(values))
        stds.append(np.std(values))

    ax.bar(steps, means, yerr=stds, alpha=0.7, capsize=5, color="blue")
    ax.set_title("MSE by Time Step")
    ax.set_xlabel("Time Step")
    ax.set_ylabel("Mean Squared Error")
    ax.grid(True)

    # Build the wandb image before closing the figure it renders.
    image = wandb.Image(fig)
    plt.close(fig)
    return image


def radially_averaged_power_spectrum_nd(image):
    """Radially average the power spectrum of an N-D array (flagged as wrong).

    The mean is subtracted and the centred (``fftshift``-ed) squared FFT
    magnitude is computed. That power is then averaged over shells of equal
    integer (truncated) Euclidean distance, in index units, from the centre
    index ``shape // 2``.

    NOTE(review): the function itself warns that its result is wrong. One
    visible candidate is that distances are in raw index units, which is
    anisotropic for non-cubic shapes. Confirm before relying on it.

    Args:
        image: NumPy array of any dimensionality.

    Returns:
        1D array whose entry ``k`` is the mean power at truncated radius ``k``.
    """
    warnings.warn("radially_averaged_power_spectrum_nd is wrong!")
    centred = image - image.mean()
    power = np.abs(np.fft.fftshift(np.fft.fftn(centred))) ** 2

    grid_shape = centred.shape
    ndim = len(grid_shape)
    middle = np.array(grid_shape) // 2
    # Broadcast the centre over the (ndim, *shape) index grid.
    offsets = np.indices(grid_shape) - middle.reshape((ndim,) + (1,) * ndim)
    radius = np.sqrt((offsets**2).sum(0)).astype(int)

    flat_radius = radius.ravel()
    shell_power = np.bincount(flat_radius, weights=power.ravel())
    shell_size = np.bincount(flat_radius)
    return shell_power / shell_size


def plot_5D_raspec(x, x2):
    """Plot radially averaged power spectra of a prediction and its ground truth.

    Both inputs go through ``radially_averaged_power_spectrum_nd`` and are
    drawn on one log-log axes: prediction in red, ground truth in black, with
    a legend.

    Args:
        x: predicted tensor; detached and moved to CPU before conversion.
        x2: ground-truth tensor; converted the same way.

    Returns:
        ``wandb.Image`` of the plot.
    """
    # Compute the spectra first, so a failing conversion does not leave an
    # open figure behind.
    raspec = radially_averaged_power_spectrum_nd(x.detach().cpu().numpy())
    gt_raspec = radially_averaged_power_spectrum_nd(x2.detach().cpu().numpy())

    fig, ax = plt.subplots(1, 1, figsize=(4, 3), layout="tight")
    ax.loglog(raspec, label="Pred spec", c="r")
    ax.loglog(gt_raspec, label="GT spec", c="k")
    ax.set_xlabel("Freq")
    ax.set_ylabel("A")
    ax.grid(True)
    # The curves carry labels; without a legend they are never displayed.
    ax.legend()
    return plt_to_wandb_image(fig)


def plot_4x4_2D_raspec(x1, x2=None, **kwargs):
    from pysteps.utils.spectral import rapsd

    _ = kwargs
    labels = [r"v_{par}", r"v_{\mu}", r"s", r"k_x", r"k_y"]

    comb = torch.combinations(torch.arange(5), 2).tolist()

    fig, ax = plt.subplots(5, 5, figsize=(20, 20))
    for i in range(5):
        for j in range(5):
            if [i, j] not in comb:
                ax[i, j].remove()

    imin = -1
    for i, j in comb:
        other = tuple([o for o in range(5) if o != i and o != j])
        xx = np.stack(
            [x1[0].permute(i, j, *other).numpy(), x1[1].permute(i, j, *other).numpy()],
            axis=-1,
        )
        xx = np.nan_to_num(xx)
        xx = np.complex64(xx)

        slices = [tuple(np.random.randint(0, dim, size=100)) for dim in xx.shape[2:]]
        slices = list(zip(*slices))
        # slices = np.ndindex(*xx.shape[2:])  # all slices

        # radially averaged power spectrum for each slice
        xx_raspec = [rapsd(xx[:, :, *sl], fft_method=np.fft) for sl in slices]
        xx_raspec_avg = np.mean(xx_raspec, axis=0)

        ax[i, j].loglog(xx_raspec_avg, label="Pred spec", c="r", lw=3)
        ax[i, j].grid(True)

        if x2 is not None:
            yy = np.stack(
                [
                    x2[0].permute(i, j, *other).numpy(),
                    x2[1].permute(i, j, *other).numpy(),
                ],
                axis=-1,
            )
            yy = np.complex64(yy)
            yy_raspec = [rapsd(yy[:, :, *sl], fft_method=np.fft) for sl in slices]
            yy_raspec_avg = np.mean(yy_raspec, axis=0)
            ax[i, j].loglog(yy_raspec_avg, label="GT spec", c="k", lw=3)

        if i > imin:
            ax[i, j].set_ylabel(rf"${labels[i]}$ (A)", fontsize=20)
            ax[i, j].set_xlabel(rf"${labels[j]}$ ($\phi$)", fontsize=20)
            imin = i

    return plt_to_wandb_image(fig)


def plot_potentials(x1, x2):
    """Plot one slice of the predicted (top) and ground-truth (bottom) potential.

    Inputs with more than 3 dimensions are reduced to their first entry. The
    original author notes this selects the real part when both real and
    imaginary parts of phi were predicted. Each field is then squeezed, and
    index 8 along its second axis is drawn transposed with the "plasma"
    colormap.

    Args:
        x1: predicted potential (array or CPU tensor).
        x2: ground-truth potential with the same layout.

    Returns:
        ``wandb.Image`` of the two-panel figure.
    """
    from matplotlib import colormaps

    palette = colormaps["plasma"]

    fig, (ax_pred, ax_gt) = plt.subplots(2, 1, figsize=(10, 5))
    # NOTE(review): wspace has no visible effect on a single-column layout.
    fig.subplots_adjust(wspace=0.05)

    def _first_channel(field):
        return field[0] if field.ndim > 3 else field

    pred, truth = _first_channel(x1), _first_channel(x2)

    panels = (
        (ax_pred, pred, r"$\phi_{pred}$"),
        (ax_gt, truth, r"$\phi_{GT}$"),
    )
    for axis, field, heading in panels:
        axis.matshow(field.squeeze()[:, 8, :].T, cmap=palette)
        axis.set_title(heading, fontsize=24)
        axis.set_ylabel(r"$y_{\phi}$", fontsize=20)
        axis.set_xticks([])
        axis.set_yticks([])

    # Only the bottom panel carries the shared x label.
    ax_gt.set_xlabel(r"$x_{\phi}$", fontsize=20)

    return plt_to_wandb_image(fig)


def _pad_limits(lo, hi, frac=0.1):
    """Return ``(lo, hi)`` widened outward by ``frac`` of each bound's magnitude.

    Unlike multiplying by ``(1 - frac, 1 + frac)``, this always moves ``lo``
    down and ``hi`` up regardless of sign, so negative data is not clipped.
    For positive bounds the result equals ``(lo * (1 - frac), hi * (1 + frac))``.
    """
    return lo - frac * abs(lo), hi + frac * abs(hi)


def flux_distribution(x1: torch.Tensor, x2: torch.Tensor):
    """Plot predicted flux samples against a single ground-truth value.

    Left panel: horizontal histogram of ``x1``. Right panel: the individual
    samples, their mean, and a ``mean +/- 1.96 * std`` band labelled "95% CI".
    The ground truth ``x2`` is drawn as a dashed line on both panels, which
    share the y-axis.

    Args:
        x1: tensor of predicted flux values; converted with
            ``.detach().cpu().numpy()`` (presumably 1D — TODO confirm).
        x2: single-element tensor holding the ground-truth flux.

    Returns:
        ``wandb.Image`` of the figure.
    """
    x1 = x1.detach().cpu().numpy()
    x2 = x2.item()
    conf_interval = 1.96 * x1.std()

    # Include the CI band and the ground truth, then pad outward. Scaling the
    # bounds by 0.9 / 1.1 instead would shrink the range for negative fluxes.
    y_min, y_max = _pad_limits(
        min(x1.min() - 0.8 * conf_interval, x2),
        max(x1.max() + 0.8 * conf_interval, x2),
    )

    fig = plt.figure(figsize=(10, 6))
    gs = gridspec.GridSpec(1, 2, width_ratios=[1, 4], wspace=0.05)

    # histogram (left)
    ax_hist = fig.add_subplot(gs[0])
    ax_hist.hist(
        x1,
        bins=15,
        orientation="horizontal",
        color="#418fc4",
        edgecolor="black",
        alpha=0.8,
    )
    ax_hist.axhline(x2, color="#d1495b", linestyle="--", linewidth=2)
    ax_hist.axhline(x1.mean(), color="#386fa4", linestyle="-", linewidth=1)
    ax_hist.set_xlim(right=ax_hist.get_xlim()[1] * 1.2)
    ax_hist.set_ylim(y_min, y_max)
    # Mirror the histogram so its bars grow towards the scatter panel.
    ax_hist.invert_xaxis()

    # confidence intervals (right)
    ax_scatter = fig.add_subplot(gs[1], sharey=ax_hist)
    ax_scatter.plot(x1, "o", color="#386fa4", label="Predictions")
    ax_scatter.axhline(x2, color="#d1495b", linestyle="--", linewidth=2)
    ax_scatter.axhline(x1.mean(), color="#386fa4", linestyle="-", linewidth=1)
    ax_scatter.fill_between(
        range(len(x1)),
        x1.mean() - conf_interval,
        x1.mean() + conf_interval,
        color="#59a5d8",
        alpha=0.2,
        label="95% CI",
    )
    ax_scatter.set_xlim(left=-0.1)
    ax_scatter.set_xlim(right=len(x1) - 0.9)

    ax_hist.tick_params(axis="x", bottom=True, labelbottom=True)
    ax_hist.tick_params(axis="y", left=True, labelleft=True)
    ax_scatter.tick_params(axis="y", left=False, labelleft=False)

    return plt_to_wandb_image(fig)


def generate_val_plots(
    rollout,
    gt,
    timestep: torch.Tensor,
    conditioning: Dict[str, float],
    stage: str = "autoencoder",
    phase: Optional[str] = None,
):
    """Render the validation figures for every plottable entry of ``rollout``.

    Args:
        rollout: mapping of predicted tensors keyed by field name. Only keys
            with registered plot functions ("df", "phi_int") are plotted; all
            others are skipped.
        gt: mapping of ground-truth tensors. A key containing "int" is matched
            to the ground-truth key with "_int" removed, and "flux" is looked
            up as "avg_flux".
        timestep: tensor whose first element appears in the plot titles.
        conditioning: currently unused.
        stage: currently unused.
        phase: tag included verbatim in the plot titles.

    Returns:
        Dict mapping each plot title to the value returned by its plot
        function.
    """
    t_value = timestep[0].item()
    registry = {
        "df": {
            # Also available but disabled: distribution_5D ("std ...") and
            # plot_4x4_2D_raspec ("2D RA spectrum ...").
            f"pred (T={t_value:.2f}, {phase})": plot4x4_sided,
        },
        "phi_int": {
            f"Integrated potentials (T={t_value:.2f}, {phase})": plot_potentials,
        },
        "flux_int": {"flux distribution": flux_distribution},
    }
    # Flux plots are switched off; keys without a registry entry are skipped.
    del registry["flux_int"]

    def merge_parity_channels(t):
        # Sum the even- and odd-indexed entries along dim 0 separately and
        # stack both sums, leaving exactly two entries along dim 0. Per the
        # original note, this recomposes a field whose zonal-flow part was
        # stored separately.
        return torch.cat(
            [t[0::2].sum(axis=0, keepdims=True), t[1::2].sum(axis=0, keepdims=True)],
            dim=0,
        )

    plots = {}
    for key in rollout.keys():
        plot_fns = registry.get(key)
        if plot_fns is None:
            continue

        gt_key = key.replace("_int", "") if "int" in key else key
        if gt_key == "flux":
            gt_key = f"avg_{gt_key}"

        x = rollout[key].clone()
        y = gt[gt_key].clone()

        if key == "df":
            x = x if x.shape[0] == 2 else merge_parity_channels(x)
            y = y if y.shape[0] == 2 else merge_parity_channels(y)

        for name, plot_fn in plot_fns.items():
            plots[name] = plot_fn(x, x2=y)

    return plots
